/*
 * Copyright (2021) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.delta

import java.io.File

import org.apache.spark.sql.delta.actions.{Action, AddFile, FileAction, SingleAction}
import org.apache.spark.sql.delta.util.{FileNames, JsonUtils}
import org.apache.spark.sql.delta.util.JsonUtils
import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{QueryTest, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

abstract class EvolvabilitySuiteBase extends QueryTest with SharedSparkSession
    with SQLTestUtils {
  import testImplicits._

  /**
   * Checks that the Delta table at `tablePath` can still be read by the current code.
   *
   * The expected rows, schema, partition schema and last-checkpoint metadata are hard-coded
   * below; presumably `tablePath` points at a checked-in fixture table written by an older
   * Delta version (TODO confirm against the callers).
   *
   * @param tablePath root directory of the Delta table to verify
   */
  protected def testEvolvability(tablePath: String): Unit = {
    // Check we can load everything from a log checkpoint
    val deltaLog = DeltaLog.forTable(spark, new Path(tablePath))
    val path = deltaLog.dataPath.toString
    checkDatasetUnorderly(
      spark.read.format("delta").load(path).select("id", "value").as[(Int, String)],
      4 -> "d", 5 -> "e", 6 -> "f")
    val metadata = deltaLog.snapshot.metadata
    assert(metadata.schema === StructType.fromDDL("id INT, value STRING"))
    assert(metadata.partitionSchema === StructType.fromDDL("id INT"))

    // Check we can load CheckpointMetaData. Fetch it once and fail with a descriptive message
    // (rather than a bare NoSuchElementException from `Option.get`) if it is absent.
    val lastCheckpoint = deltaLog.lastCheckpoint.getOrElse(
      fail(s"No last checkpoint metadata found for the table at $tablePath"))
    assert(lastCheckpoint.version === 3)
    assert(lastCheckpoint.size === 6L)

    // Check we can parse all `Action`s in delta files. It doesn't check correctness: the
    // actions of every commit are materialized only to force them to be parsed.
    deltaLog.getChanges(0L).foreach { case (_, actions) => actions.toList }
  }

  /**
   * This tests the evolution of the schema at delta file and checkpoint file.
   * Operations on the Delta table shouldn't fail when there is an unknown column
   * in delta file and checkpoint file.
   *
   * Table Schema: StructType(StructField("key", StringType), StructField("value", StringType))
   * Overwritten Delta file: {"some_new_feature":{"a":1}}
   * Overwritten checkpoint file with a new column called `unknown` with boolean type.
   *
   * The delta file and checkpoint file with an unknown column are generated by
   * `EvolvabilitySuiteBase.generateTransactionLogWithExtraColumn()`.
   *
   * @param operation receives the absolute path of a fresh temporary copy of the
   *                  `transaction_log_schema_evolvability` resource table, so it may mutate it
   */
  protected def testLogSchemaEvolvability(operation: String => Unit): Unit = {
    withTempDir { tempDir =>
      // copy the existing dir to the temp data dir.
      FileUtils.copyDirectory(
        new File("src/test/resources/delta/transaction_log_schema_evolvability"), tempDir)
      // Drop any cached DeltaLog so the copied table is loaded from disk.
      DeltaLog.clearCache()
      operation(tempDir.getAbsolutePath)
    }
  }
}


// scalastyle:off
/***
 * A tool to generate data and transaction log for evolvability tests.
 *
 * Here are the steps to generate data.
 *
 * 1. Update `EvolvabilitySuiteBase.generateData` if there are new [[Action]] types.
 * 2. Adjust the path in the following command and run it. Note: the working directory is "[delta_project_root]".
 *
 * scalastyle:off
 * ```
 * build/sbt "core/test:runMain org.apache.spark.sql.delta.EvolvabilitySuite src/test/resources/delta/delta-0.1.0 generateData"
 * ```
 *
 * You can also use this tool to generate a Delta log whose checkpoint and JSON commit files contain a new (unknown) column.
 *
 * scalastyle:off
 * ```
 * build/sbt "core/test:runMain org.apache.spark.sql.delta.EvolvabilitySuite /path/src/test/resources/delta/transaction_log_schema_evolvability generateTransactionLogWithExtraColumn"
 * ```
 */
// scalastyle:on
object EvolvabilitySuiteBase {

  /**
   * Writes a small Delta table to `path` for the evolvability tests.
   *
   * The log is built from an initial batch write, an optional
   * `ALTER TABLE ... SET TBLPROPERTIES` commit, an append, an overwrite, and a streaming
   * write from a [[MemoryStream]]. A checkpoint is written at the end.
   *
   * @param spark    session used for every write
   * @param path     directory of the Delta table to create
   * @param tblProps table properties to set after the first commit. When non-empty, a catalog
   *                 table named `test` is also created over `path`.
   */
  def generateData(
      spark: SparkSession,
      path: String,
      tblProps: Map[DeltaConfig[_], String] = Map.empty): Unit = {
    import spark.implicits._
    // `MemoryStream[Int]` below requires an implicit SQLContext.
    implicit val s = spark.sqlContext

    Seq(1, 2, 3).toDF().write.format("delta").save(path)
    if (tblProps.nonEmpty) {
      val tblPropsStr = tblProps.map { case (k, v) => s"'${k.key}' = '$v'" }.mkString(", ")
      spark.sql(s"CREATE TABLE test USING DELTA LOCATION '$path'")
      spark.sql(s"ALTER TABLE test SET TBLPROPERTIES($tblPropsStr)")
    }
    Seq(1, 2, 3).toDF().write.format("delta").mode("append").save(path)
    Seq(1, 2, 3).toDF().write.format("delta").mode("overwrite").save(path)

    // This is the streaming query's checkpoint location, not a Delta log checkpoint.
    val checkpoint = Utils.createTempDir().toString
    val data = MemoryStream[Int]
    data.addData(1, 2, 3)
    val stream = data.toDF()
      .writeStream
      .format("delta")
      .option("checkpointLocation", checkpoint)
      .start(path)
    stream.processAllAvailable()
    stream.stop()

    DeltaLog.forTable(spark, path).checkpoint()
  }

  /**
   * Validates that the JSON commits of the table at `path` contain every [[Action]] type.
   *
   * The expected set is the direct subclasses of the sealed [[Action]] trait, with
   * [[FileAction]] replaced by its own direct subclasses. The method fails if any expected
   * type is missing from the log, or if the log contains a type outside that set.
   */
  def validateData(spark: SparkSession, path: String): Unit = {
    import org.apache.spark.sql.delta.util.FileNames._
    import scala.reflect.runtime.{universe => ru}
    import spark.implicits._

    val mirror = ru.runtimeMirror(this.getClass.getClassLoader)

    val tpe = ru.typeOf[Action]
    val clazz = tpe.typeSymbol.asClass
    // `knownDirectSubclasses` is only exhaustive for a sealed type.
    assert(clazz.isSealed, s"${classOf[Action]} must be sealed")

    val deltaLog = DeltaLog.forTable(spark, new Path(path))
    val deltas = 0L to deltaLog.snapshot.version
    val deltaFiles = deltas.map(deltaFile(deltaLog.logPath, _)).map(_.toString)
    val actionsTypesInLog =
      spark.read.schema(Action.logSchema).json(deltaFiles: _*)
        .as[SingleAction]
        .collect()
        .map(_.unwrap.getClass.asInstanceOf[Class[_]])
        .toSet

    val allActionTypes =
      clazz.knownDirectSubclasses
        .flatMap {
          case t if t == ru.typeOf[FileAction].typeSymbol => t.asClass.knownDirectSubclasses
          case t => Set(t)
        }
        .map(t => mirror.runtimeClass(t.asClass))

    val missingTypes = allActionTypes -- actionsTypesInLog
    val unknownTypes = actionsTypesInLog -- allActionTypes
    assert(
      missingTypes.isEmpty,
      s"missing types: $missingTypes. " +
        "Please update EvolvabilitySuiteBase.generateData to include them in the log.")
    assert(
      unknownTypes.isEmpty,
      s"unknown types: $unknownTypes. " +
        s"Please make sure they inherit ${classOf[Action]} or ${classOf[FileAction]} directly.")
  }

  /**
   * Generates a Delta log at `path` whose checkpoint and JSON commits contain unknown fields.
   *
   * Steps:
   *  1. Write an initial version, enable struct-only checkpoint stats, and overwrite the data.
   *  2. Checkpoint. Then replace the checkpoint parquet file with a copy that has an extra
   *     boolean column `unknown`.
   *  3. Append data. Then hand-write two commits:
   *     - one that removes the appended files and carries an unknown top-level action;
   *     - one that re-adds them with an unknown field inside each `add` action and an unknown
   *       top-level field on every line.
   *
   * @param spark session used for every write
   * @param path  directory of the Delta table to create
   */
  def generateTransactionLogWithExtraColumn(spark: SparkSession, path: String): Unit = {
    import spark.implicits._
    implicit val s = spark.sqlContext

    val absPath = new File(path).getAbsolutePath

    (1 until 10).map(num => (num, num)).toDF("key", "value").write.format("delta").save(path)

    // Enable struct-only stats
    spark.sql(s"ALTER TABLE delta.`$absPath` " +
      s"SET TBLPROPERTIES (delta.checkpoint.writeStatsAsStruct = true, " +
      "delta.checkpoint.writeStatsAsJson = false)")

    (1 until 10).map(num => (num, num)).toDF("key", "value").write
      .format("delta").mode("overwrite").save(path)

    val deltaLog = DeltaLog.forTable(spark, new Path(path))

    deltaLog.checkpoint()

    // Rewrite the checkpoint just created as a copy with an extra boolean column `unknown`,
    // then swap it in place of the original checkpoint file.
    val checkpointPath = FileNames.checkpointFileSingular(deltaLog.logPath,
      deltaLog.snapshot.version)
    val tmpCheckpoint = Utils.createTempDir()
    val checkpointDataWithNewCol = spark.read.parquet(checkpointPath.toString)
      .withColumn("unknown", lit(true))

    // Write all rows into a single parquet part file so it can replace the checkpoint file.
    checkpointDataWithNewCol.coalesce(1).write
      .mode("overwrite").parquet(tmpCheckpoint.toString)
    val writtenCheckpoint = tmpCheckpoint.listFiles().toSeq
      .find(_.getName.startsWith("part"))
      .getOrElse(throw new IllegalStateException(
        s"No parquet part file was written to $tmpCheckpoint"))
    val checkpointFile = new File(checkpointPath.toUri)
    new File(deltaLog.logPath.toUri).listFiles().toSeq.foreach { file =>
      if (file.getName.startsWith(".0")) {
        // we need to delete checksum files,
        // otherwise trying to replace our modified
        // checkpoint file fails due to the LocalFileSystem's checksum checks.
        require(file.delete(), "Failed to delete checksum file")
      }
    }
    require(checkpointFile.delete(), "Failed to delete old checkpoint")
    require(writtenCheckpoint.renameTo(checkpointFile),
      "Failed to rename corrupt checkpoint")

    (1 until 10).map(num => (num, num)).toDF("key", "value").write
      .format("delta").mode("append").save(path)

    // Shouldn't fail here
    deltaLog.update()

    val version = deltaLog.snapshot.version
    // We want to have a delta log with a new column after a checkpoint, to test out operations
    // against both checkpoint with unknown column and delta log with unknown column.

    // manually remove AddFile in the previous commit and append a new column.
    val records = deltaLog.store.read(
      FileNames.deltaFile(deltaLog.logPath, version),
      deltaLog.newDeltaHadoopConf())
    val removeActions = records.map(Action.fromJson).collect { case add: AddFile => add.remove }
    val recordsWithNewAction =
      removeActions.iterator.map(_.json) ++ Iterator("""{"some_new_action":{"a":1}}""")
    deltaLog.store.write(
      FileNames.deltaFile(deltaLog.logPath, version + 1),
      recordsWithNewAction,
      overwrite = false,
      deltaLog.newDeltaHadoopConf())

    // manually add those files back and add a unknown field to it.
    val newRecords = records.map { record =>
      val recordMap = JsonUtils.fromJson[Map[String, Any]](record)
      val newRecordMap = if (recordMap.contains("add")) {
        // add a unknown column inside action fields.
        val actionFields = recordMap("add").asInstanceOf[Map[String, Any]] +
          ("some_new_column_in_add_action" -> 1)
        recordMap + ("add" -> actionFields)
      } else recordMap
      // add a unknown column outside action fields.
      // NOTE(review): the value is a Tuple2, not a Map, so its JSON shape depends on how
      // JsonUtils renders tuples; confirm against the checked-in resources before changing it.
      JsonUtils.toJson(newRecordMap + ("some_new_action_alongside_add_action" -> ("a" -> "1")))
    }.iterator
    deltaLog.store.write(
      FileNames.deltaFile(deltaLog.logPath, version + 2),
      newRecords,
      overwrite = false,
      deltaLog.newDeltaHadoopConf())

    // Shouldn't fail here
    deltaLog.update()

    DeltaLog.clearCache()
  }

  /**
   * Command-line entry point.
   *
   * Expected arguments: `<output path> <generateData|generateTransactionLogWithExtraColumn>`.
   * The output path must not exist yet. The SparkSession is stopped once generation finishes
   * or fails.
   */
  def main(args: Array[String]): Unit = {
    // Validate arguments before starting Spark, so bad input fails fast and never
    // throws a bare ArrayIndexOutOfBoundsException.
    val path = new File(args.headOption.getOrElse(
      throw new IllegalArgumentException("Missing output path argument (no data generated).")))
    if (path.exists()) {
      // Don't delete automatically in case the user types a wrong path.
      // scalastyle:off throwerror
      throw new AssertionError(s"${path.getCanonicalPath} exists. Please delete it and retry.")
      // scalastyle:on throwerror
    }
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    try {
      args.lift(1) match {
        case Some("generateData") =>
          generateData(spark, path.toString)
          validateData(spark, path.toString)
        case Some("generateTransactionLogWithExtraColumn") =>
          generateTransactionLogWithExtraColumn(spark, path.toString)
        case _ =>
          throw new RuntimeException("Unrecognized (or omitted) argument. " +
            "Please try again (no data generated).")
      }
    } finally {
      spark.stop()
    }
  }
}
